{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import mediapy as media\n",
    "import mink\n",
    "import mujoco\n",
    "import numpy as np\n",
    "\n",
    "from mujoco_playground.locomotion.go1 import go1_constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rz(phi: np.ndarray, swing_height=0.08) -> np.ndarray:\n",
    "  def cubic_bezier_interpolation(y_start, y_end, x):\n",
    "    y_diff = y_end - y_start\n",
    "    bezier = x**3 + 3 * (x**2 * (1 - x))\n",
    "    return y_start + y_diff * bezier\n",
    "\n",
    "  # Convert [-pi, pi] to [0, 1].\n",
    "  x = (phi + np.pi) / (2 * np.pi)\n",
    "  return np.where(\n",
    "      x <= 0.5,\n",
    "      cubic_bezier_interpolation(0, swing_height, 2 * x),\n",
    "      cubic_bezier_interpolation(swing_height, 0, 2 * x - 1),\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phi = np.linspace(-np.pi, np.pi, 200)\n",
    "rz = get_rz(phi)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(phi, rz)\n",
    "plt.axhline(y=0.08, color=\"r\", linestyle=\"--\", label=\"nominal swing height\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "freq = 0.5\n",
    "duration = 1.0\n",
    "ctrl_freq = 60\n",
    "gait = \"walk\"  # [\"walk\", \"trot\", \"pace\", \"bound\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctrl_dt = 1.0 / ctrl_freq\n",
    "dt = 2 * np.pi / freq * ctrl_dt\n",
    "n_steps = int(duration / ctrl_dt)\n",
    "\n",
    "gait_phases = {\n",
    "    \"walk\": np.array([0, 0.5 * np.pi, np.pi, 1.5 * np.pi]),\n",
    "    \"trot\": np.array([0, np.pi, np.pi, 0]),\n",
    "    \"pace\": np.array([0, np.pi, 0, np.pi]),\n",
    "    \"bound\": np.array([0, 0, np.pi, np.pi]),\n",
    "    \"gallop\": np.array([0, 0, 0, 0]),\n",
    "}\n",
    "\n",
    "phase_shifts = gait_phases[gait]\n",
    "phases = np.zeros((n_steps, 4))\n",
    "rs = np.zeros((n_steps, 4))\n",
    "t = 0\n",
    "for i in range(int(duration / ctrl_dt)):\n",
    "  t += dt\n",
    "  phases[i] = np.fmod(phase_shifts + t + np.pi, 2 * np.pi) - np.pi\n",
    "  rs[i] = get_rz(phases[i])\n",
    "\n",
    "# plt.plot(np.cos(phases[:, 0]), label=\"FR\")\n",
    "# plt.plot(np.cos(phases[:, 1]), label=\"FL\")\n",
    "# plt.plot(np.cos(phases[:, 2]), label=\"RR\")\n",
    "# plt.plot(np.cos(phases[:, 3]), label=\"RL\")\n",
    "# plt.legend()\n",
    "# plt.show()\n",
    "\n",
    "# plt.plot(rs[:, 0], label=\"FR\")\n",
    "# plt.plot(rs[:, 1], label=\"FL\")\n",
    "# plt.plot(rs[:, 2], label=\"RR\")\n",
    "# plt.plot(rs[:, 3], label=\"RL\")\n",
    "# plt.show()\n",
    "\n",
    "model = mujoco.MjModel.from_xml_path(str(go1_constants.FEET_ONLY_XML))\n",
    "configuration = mink.Configuration(model)\n",
    "feet = [\"FR\", \"FL\", \"RR\", \"RL\"]\n",
    "\n",
    "base_task = mink.FrameTask(\n",
    "    frame_name=\"trunk\",\n",
    "    frame_type=\"body\",\n",
    "    position_cost=1.0,\n",
    "    orientation_cost=1.0,\n",
    ")\n",
    "\n",
    "posture_task = mink.PostureTask(model, cost=1e-5)\n",
    "\n",
    "feet_tasks = []\n",
    "for foot in feet:\n",
    "  task = mink.FrameTask(\n",
    "      frame_name=foot,\n",
    "      frame_type=\"site\",\n",
    "      position_cost=1.0,\n",
    "      orientation_cost=0.0,\n",
    "  )\n",
    "  feet_tasks.append(task)\n",
    "\n",
    "tasks = [base_task, posture_task, *feet_tasks]\n",
    "\n",
    "model = configuration.model\n",
    "data = configuration.data\n",
    "solver = \"quadprog\"\n",
    "\n",
    "configuration.update_from_keyframe(\"home_higher\")\n",
    "posture_task.set_target_from_configuration(configuration)\n",
    "base_task.set_target_from_configuration(configuration)\n",
    "\n",
    "# Get current foot positions.\n",
    "feet_positions = []\n",
    "for foot in feet:\n",
    "  feet_positions.append(data.site(foot).xpos.copy())\n",
    "feet_positions = np.array(feet_positions)\n",
    "\n",
    "scene_option = mujoco.MjvOption()\n",
    "scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
    "\n",
    "frames = []\n",
    "with mujoco.Renderer(model, height=480, width=640) as renderer:\n",
    "  # Assign foot heights as targets.\n",
    "  for r in rs:\n",
    "    for i, foot in enumerate(feet):\n",
    "      foot_pos = feet_positions[i].copy()\n",
    "      foot_pos[-1] = r[i]\n",
    "      feet_tasks[i].set_target(mink.SE3.from_translation(foot_pos))\n",
    "\n",
    "    vel = mink.solve_ik(configuration, tasks, ctrl_dt, solver, 1e-5)\n",
    "    configuration.integrate_inplace(vel, ctrl_dt)\n",
    "    mujoco.mj_forward(model, data)\n",
    "\n",
    "    renderer.update_scene(data, camera=\"side\", scene_option=scene_option)\n",
    "    frames.append(renderer.render())\n",
    "media.show_video(frames, fps=(1.0 / ctrl_dt), loop=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
